{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "# Learning Bayesian Networks\n",
    "\n",
    "\n",
    "Previous notebooks showed how Bayesian networks economically encode a probability distribution over a set of variables, and how they can be used e.g. to predict variable states, or to generate new samples from the joint distribution. This section will be about obtaining a Bayesian network, given a set of sample data. Learning a Bayesian network can be split into two problems:\n",
    "\n",
    " **Parameter learning:** Given a set of data samples and a DAG that captures the dependencies between the variables, estimate the (conditional) probability distributions of the individual variables.\n",
    " \n",
    " **Structure learning:** Given a set of data samples, estimate a DAG that captures the dependencies between the variables.\n",
    " \n",
    "This notebook aims to illustrate how parameter learning and structure learning can be done with pgmpy.\n",
    "Currently, the library supports:\n",
    " - Parameter learning for *discrete* nodes:\n",
    "   - Maximum Likelihood Estimation\n",
    "   - Bayesian Estimation\n",
    " - Structure learning for *discrete*, *fully observed* networks:\n",
    "   - Score-based structure estimation (BIC/BDeu/K2 score; exhaustive search, hill climb/tabu search)\n",
    "   - Constraint-based structure estimation (PC)\n",
    "   - Hybrid structure estimation (MMHC)\n",
    "\n",
    "\n",
    "## Parameter Learning\n",
    "\n",
    "Suppose we have the following data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     fruit tasty   size\n",
      "0   banana   yes  large\n",
      "1    apple    no  large\n",
      "2   banana   yes  large\n",
      "3    apple   yes  small\n",
      "4   banana   yes  large\n",
      "5    apple   yes  large\n",
      "6   banana   yes  large\n",
      "7    apple   yes  small\n",
      "8    apple   yes  large\n",
      "9    apple   yes  large\n",
      "10  banana   yes  large\n",
      "11  banana    no  large\n",
      "12   apple    no  small\n",
      "13  banana    no  small\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "data = pd.DataFrame(data={'fruit': [\"banana\", \"apple\", \"banana\", \"apple\", \"banana\",\"apple\", \"banana\", \n",
    "                                    \"apple\", \"apple\", \"apple\", \"banana\", \"banana\", \"apple\", \"banana\",], \n",
    "                          'tasty': [\"yes\", \"no\", \"yes\", \"yes\", \"yes\", \"yes\", \"yes\", \n",
    "                                    \"yes\", \"yes\", \"yes\", \"yes\", \"no\", \"no\", \"no\"], \n",
    "                          'size': [\"large\", \"large\", \"large\", \"small\", \"large\", \"large\", \"large\",\n",
    "                                    \"small\", \"large\", \"large\", \"large\", \"large\", \"small\", \"small\"]})\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We know that the variables relate as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pgmpy.models import BayesianModel\n",
    "\n",
    "model = BayesianModel([('fruit', 'tasty'), ('size', 'tasty')])  # fruit -> tasty <- size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Parameter learning is the task to estimate the values of the conditional probability distributions (CPDs), for the variables `fruit`, `size`, and `tasty`. \n",
    "\n",
    "#### State counts\n",
    "To make sense of the given data, we can start by counting how often each state of the variable occurs. If the variable is dependent on parents, the counts are done conditionally on the parents states, i.e. for seperately for each parent configuration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "         fruit\n",
      "apple       7\n",
      "banana      7\n",
      "\n",
      " fruit apple       banana      \n",
      "size  large small  large small\n",
      "tasty                         \n",
      "no      1.0   1.0    1.0   1.0\n",
      "yes     3.0   2.0    5.0   0.0\n"
     ]
    }
   ],
   "source": [
    "from pgmpy.estimators import ParameterEstimator\n",
    "pe = ParameterEstimator(model, data)\n",
    "print(\"\\n\", pe.state_counts('fruit'))  # unconditional\n",
    "print(\"\\n\", pe.state_counts('tasty'))  # conditional on fruit and size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see, for example, that as many apples as bananas were observed and that `5` large bananas were tasty, while only `1` was not.\n",
    "\n",
    "#### Maximum Likelihood Estimation\n",
    "\n",
    "A natural estimate for the CPDs is to simply use the *relative frequencies*, with which the variable states have occured. We observed `7 apples` among a total of `14 fruits`, so we might guess that about `50%` of `fruits` are `apples`.\n",
    "\n",
    "This approach is *Maximum Likelihood Estimation (MLE)*. According to MLE, we should fill the CPDs in such a way, that $P(\\text{data}|\\text{model})$ is maximal. This is achieved when using the *relative frequencies*. See [1], section 17.1 for an introduction to ML parameter estimation. pgmpy supports MLE as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------------+-----+\n",
      "| fruit(apple)  | 0.5 |\n",
      "+---------------+-----+\n",
      "| fruit(banana) | 0.5 |\n",
      "+---------------+-----+\n",
      "+------------+--------------+--------------------+---------------------+---------------+\n",
      "| fruit      | fruit(apple) | fruit(apple)       | fruit(banana)       | fruit(banana) |\n",
      "+------------+--------------+--------------------+---------------------+---------------+\n",
      "| size       | size(large)  | size(small)        | size(large)         | size(small)   |\n",
      "+------------+--------------+--------------------+---------------------+---------------+\n",
      "| tasty(no)  | 0.25         | 0.3333333333333333 | 0.16666666666666666 | 1.0           |\n",
      "+------------+--------------+--------------------+---------------------+---------------+\n",
      "| tasty(yes) | 0.75         | 0.6666666666666666 | 0.8333333333333334  | 0.0           |\n",
      "+------------+--------------+--------------------+---------------------+---------------+\n"
     ]
    }
   ],
   "source": [
    "from pgmpy.estimators import MaximumLikelihoodEstimator\n",
    "mle = MaximumLikelihoodEstimator(model, data)\n",
    "print(mle.estimate_cpd('fruit'))  # unconditional\n",
    "print(mle.estimate_cpd('tasty'))  # conditional"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`mle.estimate_cpd(variable)` computes the state counts and divides each cell by the (conditional) sample size. The `mle.get_parameters()`-method returns a list of CPDs for all variable of the model.\n",
    "\n",
    "The built-in `fit()`-method of `BayesianModel` provides more convenient access to parameter estimators:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calibrate all CPDs of `model` using MLE:\n",
    "model.fit(data, estimator=MaximumLikelihoodEstimator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "While very straightforward, the ML estimator has the problem of *overfitting* to the data. In above CPD, the probability of a large banana being tasty is estimated at `0.833`, because `5` out of `6` observed large bananas were tasty. Fine. But note that the probability of a small banana being tasty is estimated at `0.0`, because we  observed only one small banana and it happened to be not tasty. But that should hardly make us certain that small bananas aren't tasty!\n",
    "We simply do not have enough observations to rely on the observed frequencies. If the observed data is not representative for the underlying distribution, ML estimations will be extremly far off. \n",
    "\n",
    "When estimating parameters for Bayesian networks, lack of data is a frequent problem. Even if the total sample size is very large, the fact that state counts are done conditionally for each parents configuration causes immense fragmentation. If a variable has 3 parents that can each take 10 states, then state counts will be done seperately for `10^3 = 1000` parents configurations. This makes MLE very fragile and unstable for learning Bayesian Network parameters. A way to mitigate MLE's overfitting is *Bayesian Parameter Estimation*.\n",
    "\n",
    "#### Bayesian Parameter Estimation\n",
    "\n",
    "The Bayesian Parameter Estimator starts with already existing prior CPDs, that express our beliefs about the variables *before* the data was observed. Those \"priors\" are then updated, using the state counts from the observed data. See [1], Section 17.3 for a general introduction to Bayesian estimators.\n",
    "\n",
    "One can think of the priors as consisting in *pseudo state counts*, that are added to the actual counts before normalization.\n",
    "Unless one wants to encode specific beliefs about the distributions of the variables, one commonly chooses uniform priors, i.e. ones that deem all states equiprobable.\n",
    "\n",
    "A very simple prior is the so-called *K2* prior, which simply adds `1` to the count of every single state.\n",
    "A somewhat more sensible choice of prior is *BDeu* (Bayesian Dirichlet equivalent uniform prior). For BDeu we need to specify an *equivalent sample size* `N` and then the pseudo-counts are the equivalent of having observed `N` uniform samples of each variable (and each parent configuration). In pgmpy:\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+---------------------+--------------------+--------------------+---------------------+\n",
      "| fruit      | fruit(apple)        | fruit(apple)       | fruit(banana)      | fruit(banana)       |\n",
      "+------------+---------------------+--------------------+--------------------+---------------------+\n",
      "| size       | size(large)         | size(small)        | size(large)        | size(small)         |\n",
      "+------------+---------------------+--------------------+--------------------+---------------------+\n",
      "| tasty(no)  | 0.34615384615384615 | 0.4090909090909091 | 0.2647058823529412 | 0.6428571428571429  |\n",
      "+------------+---------------------+--------------------+--------------------+---------------------+\n",
      "| tasty(yes) | 0.6538461538461539  | 0.5909090909090909 | 0.7352941176470589 | 0.35714285714285715 |\n",
      "+------------+---------------------+--------------------+--------------------+---------------------+\n"
     ]
    }
   ],
   "source": [
    "from pgmpy.estimators import BayesianEstimator\n",
    "est = BayesianEstimator(model, data)\n",
    "\n",
    "print(est.estimate_cpd('tasty', prior_type='BDeu', equivalent_sample_size=10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The estimated values in the CPDs are now more conservative. In particular, the estimate for a small banana being not tasty is now around `0.64` rather than `1.0`. Setting `equivalent_sample_size` to `10` means that for each parent configuration, we add the equivalent of 10 uniform samples (here: `+5` small bananas that are tasty and `+5` that aren't).\n",
    "\n",
    "`BayesianEstimator`, too, can be used via the `fit()`-method. Full example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------+----------+\n",
      "| A(0) | 0.503996 |\n",
      "+------+----------+\n",
      "| A(1) | 0.496004 |\n",
      "+------+----------+\n",
      "+------+-------------------+---------------------+\n",
      "| A    | A(0)              | A(1)                |\n",
      "+------+-------------------+---------------------+\n",
      "| B(0) | 0.499207135777998 | 0.49838872104733134 |\n",
      "+------+-------------------+---------------------+\n",
      "| B(1) | 0.500792864222002 | 0.5016112789526687  |\n",
      "+------+-------------------+---------------------+\n",
      "+------+---------------------+--------------------+--------------------+-------------------+\n",
      "| A    | A(0)                | A(0)               | A(1)               | A(1)              |\n",
      "+------+---------------------+--------------------+--------------------+-------------------+\n",
      "| D    | D(0)                | D(1)               | D(0)               | D(1)              |\n",
      "+------+---------------------+--------------------+--------------------+-------------------+\n",
      "| C(0) | 0.5066810768323836  | 0.4908018396320736 | 0.4844929606202816 | 0.511135414595347 |\n",
      "+------+---------------------+--------------------+--------------------+-------------------+\n",
      "| C(1) | 0.49331892316761644 | 0.5091981603679264 | 0.5155070393797184 | 0.488864585404653 |\n",
      "+------+---------------------+--------------------+--------------------+-------------------+\n",
      "+------+---------------------+---------------------+\n",
      "| B    | B(0)                | B(1)                |\n",
      "+------+---------------------+---------------------+\n",
      "| D(0) | 0.49959943921490085 | 0.49840542156667333 |\n",
      "+------+---------------------+---------------------+\n",
      "| D(1) | 0.5004005607850991  | 0.5015945784333267  |\n",
      "+------+---------------------+---------------------+\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pgmpy.models import BayesianModel\n",
    "from pgmpy.estimators import BayesianEstimator\n",
    "\n",
    "# generate data\n",
    "data = pd.DataFrame(np.random.randint(low=0, high=2, size=(5000, 4)), columns=['A', 'B', 'C', 'D'])\n",
    "model = BayesianModel([('A', 'B'), ('A', 'C'), ('D', 'C'), ('B', 'D')])\n",
    "\n",
    "model.fit(data, estimator=BayesianEstimator, prior_type=\"BDeu\") # default equivalent_sample_size=5\n",
    "for cpd in model.get_cpds():\n",
    "    print(cpd)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Structure Learning\n",
    "\n",
    "To learn model structure (a DAG) from a data set, there are two broad techniques:\n",
    "\n",
    " - score-based structure learning\n",
    " - constraint-based structure learning\n",
    "\n",
    "The combination of both techniques allows further improvement:\n",
    " - hybrid structure learning\n",
    "\n",
    "We briefly discuss all approaches and give examples.\n",
    "\n",
    "### Score-based Structure Learning\n",
    "\n",
    "\n",
    "This approach construes model selection as an optimization task. It has two building blocks:\n",
    "\n",
    "- A _scoring function_ $s_D\\colon M \\to \\mathbb R$ that maps models to a numerical score, based on how well they fit to a given data set $D$.\n",
    "- A _search strategy_ to traverse the search space of possible models $M$ and select a model with optimal score.\n",
    "\n",
    "\n",
    "#### Scoring functions\n",
    "\n",
    "Commonly used scores to measure the fit between model and data are _Bayesian Dirichlet scores_ such as *BDeu* or *K2* and the _Bayesian Information Criterion_ (BIC, also called MDL). See [1], Section 18.3 for a detailed introduction on scores. As before, BDeu is dependent on an equivalent sample size."
   ]
  },
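  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For orientation, the BIC score is commonly written as the maximized log-likelihood minus a complexity penalty. The sketch below uses symbols introduced here for illustration: $G$ is the candidate DAG, $\\hat{\\theta}_G$ its maximum likelihood parameters, $N$ the number of samples, and $\\dim(G)$ the number of independent parameters of $G$:\n",
    "\n",
    "$$\\text{BIC}(G : D) = \\log P(D \\mid \\hat{\\theta}_G, G) - \\frac{\\log N}{2}\\,\\dim(G)$$\n",
    "\n",
    "The penalty term grows with the number of parameters, so a denser graph has to improve the likelihood enough to pay for its extra edges."
   ]
  },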
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-13937.875339249586\n",
      "-14328.68417117677\n",
      "-14293.914012166675\n",
      "-20904.92154575061\n",
      "-20931.743864925047\n",
      "-20948.966605351194\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pgmpy.estimators import BDeuScore, K2Score, BicScore\n",
    "from pgmpy.models import BayesianModel\n",
    "\n",
    "# create random data sample with 3 variables, where Z is dependent on X, Y:\n",
    "data = pd.DataFrame(np.random.randint(0, 4, size=(5000, 2)), columns=list('XY'))\n",
    "data['Z'] = data['X'] + data['Y']\n",
    "\n",
    "bdeu = BDeuScore(data, equivalent_sample_size=5)\n",
    "k2 = K2Score(data)\n",
    "bic = BicScore(data)\n",
    "\n",
    "model1 = BayesianModel([('X', 'Z'), ('Y', 'Z')])  # X -> Z <- Y\n",
    "model2 = BayesianModel([('X', 'Z'), ('X', 'Y')])  # Y <- X -> Z\n",
    "\n",
    "\n",
    "print(bdeu.score(model1))\n",
    "print(k2.score(model1))\n",
    "print(bic.score(model1))\n",
    "\n",
    "print(bdeu.score(model2))\n",
    "print(k2.score(model2))\n",
    "print(bic.score(model2))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "While the scores vary slightly, we can see that the correct `model1` has a much higher score than `model2`.\n",
    "Importantly, these scores _decompose_, i.e. they can be computed locally for each of the variables given their potential parents, independent of other parts of the network:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-9191.675817737154\n",
      "-6993.847644065196\n",
      "-57.120187742958706\n"
     ]
    }
   ],
   "source": [
    "print(bdeu.local_score('Z', parents=[]))\n",
    "print(bdeu.local_score('Z', parents=['X']))\n",
    "print(bdeu.local_score('Z', parents=['X', 'Y']))"
   ]
  },
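  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The decomposition illustrated above can be summarized as a sum of local scores, one per variable given its parents in the graph $G$. The notation is chosen here for illustration: $\\ell_D$ stands for the local score and $\\mathrm{Pa}_G(X_i)$ for the parents of $X_i$ in $G$; an optional structure prior term is left out:\n",
    "\n",
    "$$s_D(G) = \\sum_{i=1}^{n} \\ell_D\\big(X_i, \\mathrm{Pa}_G(X_i)\\big)$$\n",
    "\n",
    "As a consequence, changing the parent set of one variable only requires re-evaluating that variable's local score, which is what makes local search over structures practical."
   ]
  },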
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Search strategies\n",
    "The search space of DAGs is super-exponential in the number of variables and the above scoring functions allow for local maxima. The first property makes exhaustive search intractable for all but very small networks, the second prohibits efficient local optimization algorithms to always find the optimal structure. Thus, identifiying the ideal structure is often not tractable. Despite these bad news, heuristic search strategies often yields good results.\n",
    "\n",
    "If only few nodes are involved (read: less than 5), `ExhaustiveSearch` can be used to compute the score for every DAG and returns the best-scoring one:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('X', 'Z'), ('Y', 'Z')]\n",
      "\n",
      "All DAGs by score:\n",
      "-14293.914012166677 [('X', 'Z'), ('Y', 'Z')]\n",
      "-14328.333029363766 [('X', 'Y'), ('Z', 'X'), ('Z', 'Y')]\n",
      "-14328.333029363768 [('Y', 'X'), ('Z', 'X'), ('Z', 'Y')]\n",
      "-14328.33302936377 [('Y', 'Z'), ('Y', 'X'), ('Z', 'X')]\n",
      "-14328.33302936377 [('X', 'Z'), ('Y', 'Z'), ('Y', 'X')]\n",
      "-14328.33302936377 [('X', 'Y'), ('X', 'Z'), ('Z', 'Y')]\n",
      "-14328.33302936377 [('X', 'Y'), ('X', 'Z'), ('Y', 'Z')]\n",
      "-16494.429647593526 [('X', 'Y'), ('Z', 'Y')]\n",
      "-16497.214254274455 [('Y', 'X'), ('Z', 'X')]\n",
      "-18745.666363243407 [('Z', 'X'), ('Z', 'Y')]\n",
      "-18745.66636324341 [('Y', 'Z'), ('Z', 'X')]\n",
      "-18745.666363243414 [('X', 'Z'), ('Z', 'Y')]\n",
      "-20911.76298147317 [('Z', 'Y')]\n",
      "-20911.76298147317 [('Y', 'Z')]\n",
      "-20914.5475881541 [('Z', 'X')]\n",
      "-20914.5475881541 [('X', 'Z')]\n",
      "-20946.18199867026 [('Y', 'X'), ('Z', 'Y')]\n",
      "-20946.181998670265 [('Y', 'Z'), ('Y', 'X')]\n",
      "-20946.181998670265 [('X', 'Y'), ('Y', 'Z')]\n",
      "-20948.96660535119 [('X', 'Y'), ('Z', 'X')]\n",
      "-20948.966605351194 [('X', 'Z'), ('Y', 'X')]\n",
      "-20948.966605351194 [('X', 'Y'), ('X', 'Z')]\n",
      "-23080.64420638386 []\n",
      "-23115.063223580953 [('Y', 'X')]\n",
      "-23115.063223580953 [('X', 'Y')]\n"
     ]
    }
   ],
   "source": [
    "from pgmpy.estimators import ExhaustiveSearch\n",
    "\n",
    "es = ExhaustiveSearch(data, scoring_method=bic)\n",
    "best_model = es.estimate()\n",
    "print(best_model.edges())\n",
    "\n",
    "print(\"\\nAll DAGs by score:\")\n",
    "for score, dag in reversed(es.all_scores()):\n",
    "    print(score, dag.edges())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once more nodes are involved, one needs to switch to heuristic search. `HillClimbSearch` implements a greedy local search that starts from the DAG `start` (default: disconnected DAG) and proceeds by iteratively performing single-edge manipulations that maximally increase the score. The search terminates once a local maximum is found.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('A', 'B'), ('A', 'C'), ('B', 'C'), ('G', 'A'), ('G', 'H'), ('H', 'A')]\n"
     ]
    }
   ],
   "source": [
    "from pgmpy.estimators import HillClimbSearch\n",
    "\n",
    "# create some data with dependencies\n",
    "data = pd.DataFrame(np.random.randint(0, 3, size=(2500, 8)), columns=list('ABCDEFGH'))\n",
    "data['A'] += data['B'] + data['C']\n",
    "data['H'] = data['G'] - data['A']\n",
    "\n",
    "hc = HillClimbSearch(data, scoring_method=BicScore(data))\n",
    "best_model = hc.estimate()\n",
    "print(best_model.edges())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The search correctly identifies e.g. that `B` and `C` do not influnce `H` directly, only through `A` and of course that `D`, `E`, `F` are independent.\n",
    "\n",
    "\n",
    "To enforce a wider exploration of the search space, the search can be enhanced with a tabu list. The list keeps track of the last `n` modfications; those are then not allowed to be reversed, regardless of the score. Additionally a `white_list` or `black_list` can be supplied to restrict the search to a particular subset or to exclude certain edges. The parameter `max_indegree` allows to restrict the maximum number of parents for each node.\n",
    "\n",
    "\n",
    "### Constraint-based Structure Learning\n",
    "\n",
    "A different, but quite straightforward approach to build a DAG from data is this:\n",
    "\n",
    "1. Identify independencies in the data set using hypothesis tests \n",
    "2. Construct DAG (pattern) according to identified independencies\n",
    "\n",
    "#### (Conditional) Independence Tests\n",
    "\n",
    "Independencies in the data can be identified using chi2 conditional independence tests. To this end, constraint-based estimators in pgmpy have a `test_conditional_independence(X, Y, Zs)`-method, that performs a hypothesis test on the data sample. It allows to check if `X` is independent from `Y` given a set of variables `Zs`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n",
      "True\n",
      "True\n",
      "True\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "from pgmpy.estimators import ConstraintBasedEstimator\n",
    "\n",
    "data = pd.DataFrame(np.random.randint(0, 3, size=(2500, 8)), columns=list('ABCDEFGH'))\n",
    "data['A'] += data['B'] + data['C']\n",
    "data['H'] = data['G'] - data['A']\n",
    "data['E'] *= data['F']\n",
    "\n",
    "est = ConstraintBasedEstimator(data)\n",
    "\n",
    "print(est.test_conditional_independence('B', 'H'))          # dependent\n",
    "print(est.test_conditional_independence('B', 'E'))          # independent\n",
    "print(est.test_conditional_independence('B', 'H', ['A']))   # independent\n",
    "print(est.test_conditional_independence('A', 'G'))          # independent\n",
    "print(est.test_conditional_independence('A', 'G',  ['H']))  # dependent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`test_conditional_independence()` returns a tripel `(chi2, p_value, sufficient_data)`, consisting in the computed chi2 test statistic, the `p_value` of the test, and a heuristig flag that indicates if the sample size was sufficient. The `p_value` is the probability of observing the computed chi2 statistic (or an even higher chi2 value), given the null hypothesis that X and Y are independent given Zs.\n",
    "\n",
    "This can be used to make independence judgements, at a given level of significance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n",
      "True\n",
      "True\n",
      "True\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "def is_independent(X, Y, Zs=[], significance_level=0.05):\n",
    "    return est.test_conditional_independence(X, Y, Zs)\n",
    "\n",
    "print(is_independent('B', 'H'))\n",
    "print(is_independent('B', 'E'))\n",
    "print(is_independent('B', 'H', ['A']))\n",
    "print(is_independent('A', 'G'))\n",
    "print(is_independent('A', 'G', ['H']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### DAG (pattern) construction\n",
    "\n",
    "With a method for independence testing at hand, we can construct a DAG from the data set in three steps:\n",
    "1. Construct an undirected skeleton - `estimate_skeleton()`\n",
    "2. Orient compelled edges to obtain partially directed acyclid graph (PDAG; I-equivalence class of DAGs) - `skeleton_to_pdag()`\n",
    "3. Extend DAG pattern to a DAG by conservatively orienting the remaining edges in some way - `pdag_to_dag()`\n",
    "\n",
    "Step 1.&2. form the so-called PC algorithm, see [2], page 550. PDAGs are `DirectedGraph`s, that may contain both-way edges, to indicate that the orientation for the edge is not determined.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Undirected edges:  [('A', 'B'), ('A', 'C'), ('A', 'H'), ('E', 'F'), ('G', 'H')]\n",
      "PDAG edges:        [('A', 'H'), ('B', 'A'), ('C', 'A'), ('E', 'F'), ('F', 'E'), ('G', 'H')]\n",
      "DAG edges:         [('A', 'H'), ('B', 'A'), ('C', 'A'), ('F', 'E'), ('G', 'H')]\n"
     ]
    }
   ],
   "source": [
    "skel, seperating_sets = est.estimate_skeleton(significance_level=0.01)\n",
    "print(\"Undirected edges: \", skel.edges())\n",
    "\n",
    "pdag = est.skeleton_to_pdag(skel, seperating_sets)\n",
    "print(\"PDAG edges:       \", pdag.edges())\n",
    "\n",
    "model = est.pdag_to_dag(pdag)\n",
    "print(\"DAG edges:        \", model.edges())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `estimate()`-method provides a shorthand for the three steps above and directly returns a `BayesianModel`:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('A', 'H'), ('B', 'A'), ('C', 'A'), ('F', 'E'), ('G', 'H')]\n"
     ]
    }
   ],
   "source": [
    "print(est.estimate(significance_level=0.01).edges())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `estimate_from_independencies()`-method can be used to construct a `BayesianModel` from a provided *set of independencies* (see class documentation for further features & methods):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('A', 'D'), ('B', 'D'), ('C', 'D')]\n"
     ]
    }
   ],
   "source": [
    "from pgmpy.independencies import Independencies\n",
    "\n",
    "ind = Independencies(['B', 'C'],\n",
    "                     ['A', ['B', 'C'], 'D'])\n",
    "ind = ind.closure()  # required (!) for faithfulness\n",
    "\n",
    "model = ConstraintBasedEstimator.estimate_from_independencies(\"ABCD\", ind)\n",
    "\n",
    "print(model.edges())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PC PDAG construction is only guaranteed to work under the assumption that the identified set of independencies is *faithful*, i.e. there exists a DAG that exactly corresponds to it. Spurious dependencies in the data set can cause the reported independencies to violate faithfulness. It can happen that the estimated PDAG does not have any faithful completions (i.e. edge orientations that do not introduce new v-structures). In that case a warning is issued.\n",
    "\n",
    "\n",
    "### Hybrid Structure Learning\n",
    "\n",
    "The MMHC algorithm [3] combines the constraint-based and score-based method. It has two parts:\n",
    "\n",
    "1. Learn undirected graph skeleton using the constraint-based construction procedure MMPC\n",
    "2. Orient edges using score-based optimization (BDeu score + modified hill-climbing)\n",
    "\n",
    "We can perform the two steps seperately, more or less as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ G | ['B', 'C', 'H']. At least 4860 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ F | ['B', 'C', 'H']. At least 4860 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ E | ['B', 'C', 'H']. At least 7290 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ D | ['B', 'C', 'H']. At least 4860 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ F | ['B', 'H', 'G']. At least 4860 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ F | ['C', 'H', 'G']. At least 4860 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ F | ['B', 'C', 'H', 'G']. At least 14580 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ E | ['B', 'H', 'G']. At least 7290 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ E | ['C', 'H', 'G']. At least 7290 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ E | ['B', 'C', 'H', 'G']. At least 21870 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ D | ['B', 'H', 'G']. At least 4860 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ D | ['C', 'H', 'G']. At least 4860 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing A _|_ D | ['B', 'C', 'H', 'G']. At least 14580 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing B _|_ G | ['A', 'H', 'C']. At least 3780 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing B _|_ F | ['A', 'H', 'C']. At least 3780 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing B _|_ E | ['A', 'H', 'C']. At least 5670 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing B _|_ D | ['A', 'H', 'C']. At least 3780 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing C _|_ G | ['A', 'H', 'B']. At least 3780 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing C _|_ F | ['A', 'H', 'B']. At least 3780 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing C _|_ E | ['A', 'H', 'B']. At least 5670 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing C _|_ D | ['A', 'H', 'B']. At least 3780 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing D _|_ A | ['E', 'H', 'B']. At least 6480 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing E _|_ H | ['F', 'C', 'D']. At least 3240 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing E _|_ A | ['F', 'C', 'H']. At least 7290 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing E _|_ A | ['F', 'D', 'H']. At least 7290 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing E _|_ A | ['C', 'D', 'H']. At least 7290 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing E _|_ A | ['F', 'C', 'D', 'H']. At least 21870 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing E _|_ B | ['F', 'C', 'D', 'H']. At least 7290 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing E _|_ G | ['F', 'C', 'D', 'H']. At least 7290 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing F _|_ H | ['E', 'C', 'D']. At least 2880 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing G _|_ B | ['H', 'C', 'A']. At least 3780 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing G _|_ F | ['H', 'C', 'A']. At least 3780 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing G _|_ E | ['H', 'C', 'A']. At least 5670 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing G _|_ D | ['H', 'C', 'A']. At least 3780 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing H _|_ C | ['G', 'B', 'A']. At least 5040 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing H _|_ F | ['G', 'B', 'A']. At least 5040 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing H _|_ E | ['G', 'A']. At least 2520 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing H _|_ E | ['B', 'A']. At least 2520 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing H _|_ E | ['G', 'B', 'A']. At least 7560 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n",
      "/home/ankur/pgmpy_notebook/notebooks/pgmpy/estimators/CITests.py:95: UserWarning: Insufficient data for testing H _|_ D | ['G', 'B', 'A']. At least 5040 samples recommended, 2500 present.\n",
      "  5 * num_params, len(data)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Part 1) Skeleton:  [('A', 'C'), ('A', 'H'), ('B', 'C'), ('E', 'F'), ('G', 'H')]\n",
      "Part 2) Model:     [('A', 'H'), ('C', 'A'), ('F', 'E'), ('G', 'H')]\n"
     ]
    }
   ],
   "source": [
    "from pgmpy.estimators import MmhcEstimator\n",
    "from pgmpy.estimators import BDeuScore\n",
    "\n",
    "data = pd.DataFrame(np.random.randint(0, 3, size=(2500, 8)), columns=list('ABCDEFGH'))\n",
    "data['A'] += data['B'] + data['C']\n",
    "data['H'] = data['G'] - data['A']\n",
    "data['E'] *= data['F']\n",
    "\n",
    "mmhc = MmhcEstimator(data)\n",
    "skeleton = mmhc.mmpc()\n",
    "print(\"Part 1) Skeleton: \", skeleton.edges())\n",
    "\n",
    "# use hill climb search to orient the edges:\n",
    "hc = HillClimbSearch(data, scoring_method=BDeuScore(data))\n",
    "model = hc.estimate(tabu_length=10, white_list=skeleton.to_directed().edges())\n",
    "print(\"Part 2) Model:    \", model.edges())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MmhcEstimator.estimate()` is a shorthand for both steps and directly estimates a `BayesianModel`.\n",
    "\n",
    "### Conclusion\n",
    "\n",
    "This notebook aimed to give an overview of pgmpy's estimators for learning Bayesian network structure and parameters. For more information about the individual functions see their docstring documentation. If you used pgmpy's structure learning features to satisfactorily learn a non-trivial network from real data, feel free to drop us an eMail via the mailing list or just open a Github issue. We'd like to put your network in the examples-section!\n",
    "\n",
    "### References\n",
    "\n",
    "[1] Koller & Friedman, Probabilistic Graphical Models - Principles and Techniques, 2009\n",
    "\n",
    "[2] Neapolitan, [Learning Bayesian Networks](http://www.cs.technion.ac.il/~dang/books/Learning%20Bayesian%20Networks&#40;Neapolitan,%20Richard&#41;.pdf), 2003\n",
    "\n",
    "[3] Tsamardinos et al., [The max-min hill-climbing BN structure learning algorithm](http://www.dsl-lab.org/supplements/mmhc_paper/paper_online.pdf), 2005\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
